Score Matching for Density Estimation

A first post using Quarto

Optimization
Score Matching
Julia
Score matching is a method for indirectly estimating the probability density function of a distribution. In this post, I will explain the score matching method as well as some of its limitations.
Author

Simon Ghyselincks

Published

May 22, 2024

Big Idea

Denoising autoencoders (DAEs) are machine learning models trained to reconstruct input data from a noisy or corrupted version of it. Given a sample such as an image with unwanted noise, a trained DAE restores it to the original, clean sample.

In the process of learning to denoise, the DAE can also learn the score function of the underlying distribution of noisy samples, which is a kernel density estimate of the true distribution.
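To make that connection concrete: if clean samples \(x_i\) are corrupted with Gaussian noise of standard deviation \(h\), the noisy samples follow the Gaussian kernel density estimate \(\hat q_h\), and the best possible mean-squared-error denoiser returns the posterior mean \(D(\tilde x) = \mathbb{E}[x_i \mid \tilde x]\). Tweedie's formula then gives \(\nabla \log \hat q_h(\tilde x) = (D(\tilde x) - \tilde x)/h^2\), so a good denoiser implicitly encodes the score. The 1D sketch below checks that identity numerically on toy data; it uses the ideal denoiser in closed form rather than a trained network, and all names and values are illustrative.

```julia
xs = [-1.2, -0.4, 0.3, 1.5, 2.0]   # toy "clean" samples (illustrative)
h = 0.5                             # noise standard deviation = KDE bandwidth

# Log of the 1D Gaussian KDE built on xs
logkde(x) = log(sum(exp(-(x - xi)^2 / (2h^2)) for xi in xs) / (length(xs) * sqrt(2π) * h))

# Ideal denoiser: posterior mean of the clean sample given the noisy value x
function denoise(x)
    w = [exp(-(x - xi)^2 / (2h^2)) for xi in xs]   # unnormalized posterior weights
    return sum(w .* xs) / sum(w)
end

# Tweedie's formula: score of the noisy (KDE) distribution from the denoiser residual
score_from_denoiser(x) = (denoise(x) - x) / h^2

x, δ = 0.7, 1e-6
score_fd = (logkde(x + δ) - logkde(x - δ)) / (2δ)   # centered finite-difference score
println(score_from_denoiser(x), " ≈ ", score_fd)
```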

The score function is an operator defined as

\[ s(f(x)) = \nabla_x \log f(x), \]

where \(f(x)\) is the probability density function (PDF) of the distribution.
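For a Gaussian \(\mathcal{N}(\mu, \Sigma)\) the score has a closed form, \(\nabla_x \log f(x) = -\Sigma^{-1}(x - \mu)\): it is linear in \(x\) and points toward the mean. A minimal sketch comparing that formula with a centered finite difference of a hand-written log-density; it uses only the standard library, and the helper names are illustrative.

```julia
using LinearAlgebra

# Log-density of N(μ, Σ), written out by hand
function gaussian_logpdf(x, μ, Σ)
    d = length(μ)
    r = x - μ
    return -0.5 * (d * log(2π) + logdet(Σ) + dot(r, Σ \ r))
end

# Closed-form score of N(μ, Σ): ∇ₓ log f(x) = -Σ⁻¹ (x - μ)
gaussian_score(x, μ, Σ) = -(Σ \ (x - μ))

# Centered finite-difference approximation of ∇ₓ log f, one coordinate at a time
function fd_gradient(logf, x; δ = 1e-5)
    n = length(x)
    unit(i) = [j == i ? 1.0 : 0.0 for j in 1:n]   # i-th standard basis vector
    return [(logf(x + δ * unit(i)) - logf(x - δ * unit(i))) / (2δ) for i in 1:n]
end

μ = [1.0, -1.0]
Σ = [0.5 0.3; 0.3 0.5]
x = [0.2, 0.4]

s_exact = gaussian_score(x, μ, Σ)
s_fd = fd_gradient(z -> gaussian_logpdf(z, μ, Σ), x)
println(s_exact, " ≈ ", s_fd)
```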

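Because the score determines \(\log f\) up to an additive constant, learning the score is enough to recover the density up to its normalizing constant, which can then be fixed by requiring the density to integrate to one. This is the idea behind score matching: we indirectly find the PDF of a distribution by matching the score of a proposed model \(p(x;\theta)\) to the score of the true distribution \(q(x)\). A practical advantage is that the normalizing constant of \(p(x;\theta)\) vanishes from its score, so it never has to be computed.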
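In one dimension, undoing the score is just integration: \(\log f(x) = \log f(x_0) + \int_{x_0}^{x} s(t)\,dt\), where the unknown constant \(\log f(x_0)\) is fixed by normalization. A minimal sketch, assuming a standard normal target so the result can be checked, using a cumulative trapezoid rule on a uniform grid (grid range and resolution are arbitrary choices):

```julia
# Score of a standard normal; the normalizing constant has already dropped out
std_normal_score(x) = -x

xs = range(-6, 6; length = 2001)
h = step(xs)

# Cumulative trapezoid integral of the score: log f up to an unknown additive constant
logf = zeros(length(xs))
for i in 2:length(xs)
    logf[i] = logf[i-1] + h * (std_normal_score(xs[i-1]) + std_normal_score(xs[i])) / 2
end

# Exponentiate, then normalize with the trapezoid rule to fix the missing constant
f = exp.(logf .- maximum(logf))
f ./= h * (sum(f) - (f[1] + f[end]) / 2)

f_true = exp.(-xs .^ 2 ./ 2) ./ sqrt(2π)
println(maximum(abs.(f .- f_true)))   # small: the density is recovered
```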

Another benefit of learning the score function is that it lets us move from less probable to more probable regions of the distribution by gradient ascent on the log-density. This is useful for generative models, where we want to generate new samples that are more probable under the distribution.
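A minimal sketch of that idea, assuming a single 2D Gaussian target so the score is known exactly; the step size η and the number of steps are arbitrary illustrative choices. Only the score is used, never the normalized density. Note that pure gradient ascent converges to a mode rather than producing a spread of samples.

```julia
using LinearAlgebra

# Illustrative target: a 2D Gaussian N(μ, Σ) with a known score
μ = [1.0, -1.0]
Σ = [0.5 0.0; 0.0 0.5]
target_score(x) = -(Σ \ (x - μ))

# Gradient ascent on log f using only the score: x ← x + η s(x)
function ascend(score, x0; η = 0.1, steps = 200)
    x = copy(x0)
    for _ in 1:steps
        x = x + η * score(x)
    end
    return x
end

x_final = ascend(target_score, [-2.5, 2.5])
println(x_final)   # close to the mode μ
```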

However, one challenge is that the score can be poorly determined in regions of low probability, where samples are sparse. This makes it difficult to learn the score function accurately in these regions.

This post explores some of those limitations and how increasing the bandwidth of the noise kernel in the DAE can help to stabilize the score function in regions of low probability.

Sample of Score Matching

Suppose our ground truth is a distribution in 2D made of three Gaussians, mixed with weights 0.2, 0.2, and 0.6. We can plot this PDF and, further on, its gradient field.

```julia
using Plots, Distributions

# Ground truth q(x): a mixture of three 2D Gaussians with weights 0.2, 0.2, 0.6
function p(x, y)
    mu1, mu2, mu3 = [-1.0, -1.0], [1.0, 1.0], [1.0, -1.0]
    sigma1, sigma2, sigma3 = [0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0.0; 0.0 0.5]

    return 0.2 * pdf(MvNormal(mu1, sigma1), [x, y]) +
           0.2 * pdf(MvNormal(mu2, sigma2), [x, y]) +
           0.6 * pdf(MvNormal(mu3, sigma3), [x, y])
end

# Plot the distribution as a heatmap
heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x)",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)
```

We can draw 200 samples from the distribution by rejection sampling: propose points uniformly on the square \([-3, 3]^2\) and accept each one with probability \(q(x)\).

```julia
using Plots, Distributions

# Rejection sampling with a uniform proposal on [-3, 3]²: accept (x, y) with probability p(x, y).
# Valid because p never exceeds 1. p is the ground-truth density defined in the first code block.
n_points = 200
points = Tuple{Float64, Float64}[]

while length(points) < n_points
    x = rand() * 6 - 3
    y = rand() * 6 - 3
    if rand() < p(x, y)
        push!(points, (x, y))
    end
end

# Scatter plot of the sampled points
scatter([x for (x, y) in points], [y for (x, y) in points], label="Sampled Points", color=:red, ms=2,
        xlims=(-3, 3), ylims=(-3, 3),
        xticks=[-3, 3], yticks=[-3, 3])
```
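The rejection sampler above proposes points uniformly on \([-3, 3]^2\) and accepts each with probability equal to the density. That acceptance test is valid here because the mixture density never exceeds 1, but it only samples the part of the mixture inside the square and discards roughly 97% of proposals. For comparison, a minimal sketch of ancestral sampling, assuming the mixture parameters are known (true for this synthetic ground truth): pick a component with probability equal to its weight, then draw from that Gaussian through a Cholesky factor. It uses only the standard library; `sample_mixture` and the other names are illustrative.

```julia
using LinearAlgebra, Random

# Same mixture parameters as the ground truth
ws = [0.2, 0.2, 0.6]
μs = [[-1.0, -1.0], [1.0, 1.0], [1.0, -1.0]]
Σs = [[0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0.0; 0.0 0.5]]
Ls = [cholesky(Symmetric(Σ)).L for Σ in Σs]   # Σ = L Lᵀ

# One draw: choose component k with probability ws[k], then return μ_k + L_k z with z ~ N(0, I)
function sample_mixture(rng)
    u = rand(rng)
    k = something(findfirst(c -> c >= u, cumsum(ws)), length(ws))
    return μs[k] + Ls[k] * randn(rng, 2)
end

rng = MersenneTwister(1)
samples = [sample_mixture(rng) for _ in 1:10_000]
println(sum(samples) / length(samples))   # near the mixture mean [0.6, -0.6]
```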

Before building kernel density estimates from these samples, we overlay the true score \(\nabla_x \log q(x)\) on the PDF, computing the gradient with automatic differentiation (ForwardDiff.jl).

```julia
using Plots, Distributions, ForwardDiff

# Log of the ground-truth density p from the first code block (-Inf where p is 0)
function log_p(x, y)
    val = p(x, y)
    return val > 0 ? log(val) : -Inf
end

# Score ∇ log q at (u, v) via forward-mode automatic differentiation
function gradient_log_p(u, v)
    grad = ForwardDiff.gradient(z -> log_p(z[1], z[2]), [u, v])
    return grad[1], grad[2]
end

# Coarse grid on which to draw score arrows
xs = -3:0.5:3
ys = -3:0.5:3

# Arrow positions and score components, flattened in the same order (y varies fastest)
xxs_flat = [x for x in xs for y in ys]
yys_flat = [y for x in xs for y in ys]
grads = [gradient_log_p(x, y) for x in xs for y in ys]
U = first.(grads)
V = last.(grads)

# Plot the ground-truth PDF as a heatmap
heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x) with score",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Overlay the score as a vector field (arrows scaled down by 20 for readability)
quiver!(xxs_flat, yys_flat, quiver=(U ./ 20, V ./ 20), color=:green)
```
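For a Gaussian mixture \(q(x) = \sum_k w_k \mathcal{N}(x; \mu_k, \Sigma_k)\) the score also has a closed form: each component's score \(-\Sigma_k^{-1}(x - \mu_k)\) is weighted by its posterior responsibility \(\gamma_k(x) = w_k \mathcal{N}(x; \mu_k, \Sigma_k) / q(x)\). The sketch below evaluates that formula for the same three components using only LinearAlgebra, as an illustrative cross-check of the automatic-differentiation approach rather than a replacement for it; the helper names are hypothetical.

```julia
using LinearAlgebra

# Same mixture as the ground truth: weights, means, covariances
ws = [0.2, 0.2, 0.6]
μs = [[-1.0, -1.0], [1.0, 1.0], [1.0, -1.0]]
Σs = [[0.5 0.3; 0.3 0.5], [0.5 0.3; 0.3 0.5], [0.5 0.0; 0.0 0.5]]

# Density of a 2D Gaussian N(μ, Σ)
gauss(x, μ, Σ) = exp(-0.5 * dot(x - μ, Σ \ (x - μ))) / (2π * sqrt(det(Σ)))

# Mixture score: responsibility-weighted average of the component scores
function mixture_score(x)
    dens = [w * gauss(x, μ, Σ) for (w, μ, Σ) in zip(ws, μs, Σs)]
    γ = dens ./ sum(dens)   # posterior responsibilities
    return sum(γ[k] * (-(Σs[k] \ (x - μs[k]))) for k in eachindex(γ))
end

# Cross-check against a centered finite difference of log q
logq(x) = log(sum(w * gauss(x, μ, Σ) for (w, μ, Σ) in zip(ws, μs, Σs)))
x0, δ = [0.3, -0.2], 1e-6
fd = [(logq(x0 + δ * e) - logq(x0 - δ * e)) / (2δ) for e in ([1.0, 0.0], [0.0, 1.0])]
println(mixture_score(x0), " ≈ ", fd)
```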

Now we apply a Gaussian kernel to the sample points to create the kernel density estimate:

```julia
using Plots, Distributions, KernelDensity

# Split the sampled points into x and y vectors
x_points = [x for (x, y) in points]
y_points = [y for (x, y) in points]

# Gaussian kernel density estimate with bandwidth 0.3 in each direction;
# parzen.density[i, j] is the estimate at (parzen.x[i], parzen.y[j])
parzen = kde((x_points, y_points); boundary=((-3, 3), (-3, 3)), bandwidth=(0.3, 0.3))

# Plot the ground truth PDF
p1 = heatmap(
    -3:0.01:3, -3:0.01:3, p,
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Ground Truth PDF q(x)",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Scatter plot of the sampled points on top of the ground truth PDF
scatter!(p1, x_points, y_points, label="Sampled Points", color=:red, ms=2)

# Plot the kernel density estimate (transposed, since heatmap rows index y)
p2 = heatmap(
    parzen.x, parzen.y, parzen.density',
    c=cgrad(:davos, rev=true),
    aspect_ratio=:equal,
    xlabel="x", ylabel="y", title="Kernel Density Estimate",
    xlims=(-3, 3), ylims=(-3, 3),
    xticks=[-3, 3], yticks=[-3, 3]
)

# Scatter plot of the sampled points on top of the kernel density estimate
scatter!(p2, x_points, y_points, label="Sampled Points", color=:red, ms=2)

# Arrange the plots side by side
plot(p1, p2, layout=@layout([a b]), size=(800, 400))
```
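For reference, the Gaussian kernel density estimate with bandwidth \(h\) is \(\hat q_h(x) = \frac{1}{n}\sum_{i=1}^{n} \mathcal{N}(x; x_i, h^2 I)\). KernelDensity.jl evaluates this on a grid using FFTs; the sketch below evaluates the same formula directly at single points, which is slower but makes the definition explicit. It assumes an isotropic bandwidth equal to the kernel's standard deviation, and `kde_density` and the toy samples are illustrative.

```julia
# Direct evaluation of a 2D Gaussian KDE with isotropic bandwidth h
function kde_density(z, samples, h)
    s = 0.0
    for (sx, sy) in samples
        r2 = (z[1] - sx)^2 + (z[2] - sy)^2
        s += exp(-r2 / (2h^2)) / (2π * h^2)   # N(z; xᵢ, h² I)
    end
    return s / length(samples)
end

samples = [(0.0, 0.0), (1.0, 0.5), (-0.5, 1.0)]   # toy data
h = 0.3

# Riemann-sum check that the estimate integrates to about 1 over a wide square
grid = -4:0.02:4
mass = sum(kde_density((x, y), samples, h) for x in grid, y in grid) * step(grid)^2
println(mass)
```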

Looking at the density estimate across many bandwidths shows the effect of adding more and more noise to the sampled points. As the bandwidth grows the clusters blur together, and at very large bandwidths the estimate flattens into a single broad bump that looks nearly uniform over the plotted window.
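One way to see why large bandwidths wash out the structure: the Gaussian KDE is an equal-weight mixture of \(n\) Gaussians with covariance \(h^2 I\) centered on the samples, so its covariance is the sample covariance (with the \(1/n\) normalization) plus \(h^2 I\). Once \(h^2\) is much larger than the spread of the data, that added term dominates and the estimate looks like one wide, nearly isotropic bump. The sketch below checks the covariance identity by sampling from a KDE built on toy data; the names and sizes are illustrative.

```julia
using Random, Statistics, LinearAlgebra

rng = MersenneTwister(0)
data = randn(rng, 2, 200)   # toy samples, one per column
n = size(data, 2)
h = 2.0

# Closed form: covariance of the KDE = biased sample covariance + h² I
Σ_kde = cov(data; dims = 2, corrected = false) + h^2 * I

# Monte Carlo check: draw from the KDE by picking a sample and adding N(0, h² I) noise
m = 200_000
idx = rand(rng, 1:n, m)
draws = data[:, idx] .+ h .* randn(rng, 2, m)
Σ_mc = cov(draws; dims = 2)

println(Σ_kde, "\n", Σ_mc)
```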

```julia
using Plots, Distributions, KernelDensity

# Bandwidths from 0.01 to 2.01 in steps of 0.05 (same value in x and y)
bandwidths = [(0.01 + 0.05 * i, 0.01 + 0.05 * i) for i in 0:40]

# One animation frame per bandwidth
anim = @animate for bw in bandwidths
    kde_result = kde((x_points, y_points); boundary=((-6, 6), (-6, 6)), bandwidth=bw)

    p2 = heatmap(
        kde_result.x, kde_result.y, kde_result.density',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Kernel Density Estimate, Bandwidth = $(round(bw[1], digits=2))",
        xlims=(-6, 6), ylims=(-6, 6),
        xticks=[-6, 6], yticks=[-6, 6]
    )

    scatter!(p2, x_points, y_points, label="Sampled Points", color=:red, ms=2)
end

# Save the animation as a GIF
gif(anim, "parzen_density_animation_with_gradients.gif", fps=2, show_msg=false)
```

Next we compute the score of the kernel density estimate to see how it changes with the bandwidth. The score is numerically unstable in regions of sparse data. Recall that the score is the gradient of the log-density function, and the log-density tends to negative infinity as the density approaches zero. Within the limits of floating-point precision, a density that is small enough becomes exactly zero, its log is negative infinity, and finite differences of the log (for example \(-\infty - (-\infty)\)) are undefined. Higher KDE bandwidths, for example with a Gaussian kernel, spread the mass of each sample over a wider area, which keeps the density representable farther from the data and so extends the region of numerical stability.
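The same failure can be reproduced without FFTs. A minimal sketch with a directly evaluated 1D Gaussian KDE on toy data (names and numbers are illustrative): at a point 4.5 units from the nearest sample, every kernel term underflows to zero in Float64 when \(h = 0.1\), so the log-density is \(-\infty\) and the centered difference is NaN, while \(h = 0.3\) keeps both finite.

```julia
# Toy 1D data (illustrative)
samples = [-1.0, 0.0, 0.5]

# Directly evaluated 1D Gaussian KDE with bandwidth h (no FFT)
kde1(x, h) = sum(exp(-(x - s)^2 / (2h^2)) for s in samples) / (length(samples) * sqrt(2π) * h)

# Centered finite-difference score of the log-density, as on a grid
fd_score_1d(x, h; δ = 1e-3) = (log(kde1(x + δ, h)) - log(kde1(x - δ, h))) / (2δ)

x_far = 5.0   # 4.5 units from the nearest sample
for h in (0.1, 0.3, 1.0)
    println("h = $h: log density = ", log(kde1(x_far, h)), ", score ≈ ", fd_score_1d(x_far, h))
end
```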

The regions with poor numerical stability show up as noise artifacts and missing values (NaN or infinite entries) in the partial derivatives of the log-density. Some of these artifacts may also come from round-off in the Fourier transform calculations that KernelDensity.jl uses to compute the estimate.

```julia
using Plots, Distributions, KernelDensity

# Bandwidths from 0.01 to 1.51 in steps of 0.05 (same value in x and y)
bandwidths = [(0.01 + 0.05 * i, 0.01 + 0.05 * i) for i in 0:30]

boundary = (-10, 10)

# One animation frame per bandwidth
anim = @animate for bw in bandwidths
    kde_result = kde((x_points, y_points); boundary=(boundary, boundary), bandwidth=bw)

    # Log-density; clamp any tiny negative values from FFT round-off to 0 so that
    # log returns -Inf (marking unusable regions) instead of throwing a DomainError
    log_density = log.(max.(kde_result.density, 0.0))

    # Score via centered finite differences of the log-density (interior points only)
    grad_x = zeros(size(log_density))
    grad_y = zeros(size(log_density))
    for i in 2:size(log_density, 1)-1
        for j in 2:size(log_density, 2)-1
            grad_x[i, j] = (log_density[i+1, j] - log_density[i-1, j]) / (kde_result.x[i+1] - kde_result.x[i-1])
            grad_y[i, j] = (log_density[i, j+1] - log_density[i, j-1]) / (kde_result.y[j+1] - kde_result.y[j-1])
        end
    end

    # Heatmaps of the partial derivatives (transposed, since heatmap rows index y)
    p1 = heatmap(
        kde_result.x, kde_result.y, grad_x',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Partial Derivative of Log-Density wrt x \n Bandwidth = $(round(bw[1], digits=2))",
        xlims=boundary, ylims=boundary
    )

    # Overlay the scatter plot of the sampled points
    scatter!(p1, x_points, y_points, label="Sampled Points", color=:red, ms=2)

    p2 = heatmap(
        kde_result.x, kde_result.y, grad_y',
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Partial Derivative of Log-Density wrt y \n Bandwidth = $(round(bw[1], digits=2))",
        xlims=boundary, ylims=boundary
    )

    # Overlay the scatter plot of the sampled points
    scatter!(p2, x_points, y_points, label="Sampled Points", color=:red, ms=2)

    plot(p1, p2, layout=@layout([a b]), size=(800, 400))
end

# Save the animation as a GIF
gif(anim, "parzen_density_partials.gif", fps=2, show_msg=false)
```

Finally, we overlay the score of the kernel density estimate on the ground-truth distribution it models, starting with large bandwidths and moving to small ones. The larger bandwidths extend the region of numerical stability, but they also cost precision: as the bandwidth grows, the model approaches a single broad Gaussian and loses the structure of the three components.

```julia
using Plots, Distributions, KernelDensity

# Bandwidths from 2.01 down to 0.01 in steps of 0.2
bandwidths = reverse([(0.01 + 0.2 * i, 0.01 + 0.2 * i) for i in 0:10])

boundary = (-10, 10)

# Ground-truth PDF on a fine grid, computed once (rows index y, columns index x)
x_range = boundary[1]:0.01:boundary[2]
y_range = boundary[1]:0.01:boundary[2]
q_grid = [p(x, y) for y in y_range, x in x_range]

# One animation frame per bandwidth
anim = @animate for bw in bandwidths
    kde_result = kde((x_points, y_points); boundary=(boundary, boundary), bandwidth=bw)

    # Log-density, clamping any tiny negative FFT round-off to 0 (log(0) = -Inf)
    log_density = log.(max.(kde_result.density, 0.0))

    # Score via centered finite differences of the log-density (interior points only)
    grad_x = zeros(size(log_density))
    grad_y = zeros(size(log_density))
    for i in 2:size(log_density, 1)-1
        for j in 2:size(log_density, 2)-1
            grad_x[i, j] = (log_density[i+1, j] - log_density[i-1, j]) / (kde_result.x[i+1] - kde_result.x[i-1])
            grad_y[i, j] = (log_density[i, j+1] - log_density[i, j-1]) / (kde_result.y[j+1] - kde_result.y[j-1])
        end
    end

    # Downsample the gradients and coordinates by keeping every 20th grid point
    ix = 1:20:size(grad_x, 1)
    iy = 1:20:size(grad_y, 2)
    x_down = kde_result.x[ix]
    y_down = kde_result.y[iy]

    # Flatten positions in the same column-major order as the gradient matrices
    xxs_flat = vec([x for x in x_down, y in y_down])
    yys_flat = vec([y for x in x_down, y in y_down])
    grad_x_flat = vec(grad_x[ix, iy])
    grad_y_flat = vec(grad_y[ix, iy])

    # Plot the actual distribution
    heatmap(
        x_range, y_range, q_grid,
        c=cgrad(:davos, rev=true),
        aspect_ratio=:equal,
        xlabel="x", ylabel="y", title="Ground Truth PDF q(x)\n with score of Kernel Density Estimate, \n Bandwidth = $(round(bw[1], digits=2))",
        xlims=boundary, ylims=boundary,
        size=(800, 800)
    )

    # Quiver plot of the downsampled KDE score (arrows scaled down by 10)
    quiver!(xxs_flat, yys_flat, quiver=(grad_x_flat ./ 10, grad_y_flat ./ 10),
            color=:green, aspect_ratio=:equal)
end

# Save the animation as a GIF
gif(anim, "parzen_density_gradient_animation_with_gradients.gif", fps=2, show_msg=false)
```